#!/usr/bin/env python3
import latticex.rosetta as rtt  # difference from tensorflow
import math
import os
import csv
import tensorflow as tf
import numpy as np
from util import read_dataset
from test_fw import mt_unit_test_fw
import time

def test(task_id):
  """Train a linear-regression model under the SecureNN MPC protocol.

  Loads per-node private CSV training data, runs mini-batch gradient
  descent on a mean-squared-error loss for a fixed number of epochs,
  prints the Rosetta runtime performance statistics, and deactivates
  the protocol for this task.

  Args:
    task_id: identifier threaded through every task-scoped Rosetta
      call (activate, Session, deactivate, ...). Exact type/format is
      defined by the caller's framework — presumably a string; confirm
      against mt_unit_test_fw.
  """
  rtt.py_protocol_handler.set_loglevel(0)
  np.set_printoptions(suppress=True)
  
  # Silence TensorFlow C++ logging ('3' = errors only).
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
  
  # Fixed seed for reproducible runs.
  np.random.seed(0)
  
  # Training hyper-parameters.
  EPOCHES = 10
  BATCH_SIZE = 16
  learning_rate = 0.0002
  
  # Activate the SecureNN protocol before building any secure ops.
  rtt.activate("SecureNN", task_id = task_id)
  node_id = rtt.get_current_node_id(task_id = task_id)
  dg = tf.Graph()
  with dg.as_default():
    
    # real data
    # ######################################## difference from tensorflow
    # Each node reads its own shard from ../dsets/<node_id>/.
    file_x = '../dsets/' + node_id + "/reg_train_x.csv"
    file_y = '../dsets/' + node_id + "/reg_train_y.csv"
    # Features are jointly owned by parties 0 and 1; labels by party 1.
    real_X, real_Y = rtt.PrivateDataset(data_owner=(
        0, 1), label_owner=1, task_id = task_id).load_data(file_x, file_y, header=None)
    # ######################################## difference from tensorflow
    DIM_NUM = real_X.shape[1]  # number of input features
    
    X = tf.placeholder(tf.float64, [None, DIM_NUM])
    Y = tf.placeholder(tf.float64, [None, 1])
    print(X)
    print(Y)
    
    # initialize W & b (zero-initialized weights and bias)
    W = tf.Variable(tf.zeros([DIM_NUM, 1], dtype=tf.float64))
    b = tf.Variable(tf.zeros([1], dtype=tf.float64))
    print(W)
    print(b)
    
    # predict: linear model X @ W + b
    pred_Y = tf.matmul(X, W) + b
    print(pred_Y)
    
    # loss: mean squared error
    loss = tf.square(Y - pred_Y)
    loss = tf.reduce_mean(loss)
    print(loss)
    
    # optimizer: plain SGD on the MSE loss
    train = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
    print(train)
    
    init = tf.global_variables_initializer()
    print(init)
    
    # ########### for test, reveal
    # SecureReveal ops decode secret-shared values back to plaintext
    # (currently only wired up; the sess.run calls below are disabled).
    reveal_W = rtt.SecureReveal(W)
    reveal_b = rtt.SecureReveal(b)
    reveal_Y = rtt.SecureReveal(pred_Y)
    # ########### for test, reveal
    
    config = tf.ConfigProto(inter_op_parallelism_threads = 16, intra_op_parallelism_threads = 16) 
    with tf.Session(task_id = task_id, config = config) as sess:
        sess.run(init)
        #rW, rb = sess.run([reveal_W, reveal_b])
        #print("init weight:{} \nbias:{}".format(rW, rb))
    
        # train: mini-batch SGD; ceil so the final short batch is included
        BATCHES = math.ceil(len(real_X) / BATCH_SIZE)
        for e in range(EPOCHES):
            for i in range(BATCHES):
                # Slicing past the end is safe: the last batch is simply shorter.
                bX = real_X[(i * BATCH_SIZE): (i + 1) * BATCH_SIZE]
                bY = real_Y[(i * BATCH_SIZE): (i + 1) * BATCH_SIZE]
                print('*' * 80, task_id)
                sess.run(train, feed_dict={X: bX, Y: bY})
                print('#' * 80, task_id)
    
                # Every 50th global step (and the final step, if not a
                # multiple of 50) — reveal logging currently disabled.
                j = e * BATCHES + i
                if j % 50 == 0 or (j == EPOCHES * BATCHES - 1 and j % 50 != 0):
                    pass
                    #rW, rb = sess.run([reveal_W, reveal_b])
                    #print("I,E,B:{:0>4d},{:0>4d},{:0>4d} weight:{} \nbias:{}".format(
                    #    j, e, i, rW, rb))
    
        # predict
        #Y_pred = sess.run(reveal_Y, feed_dict={X: real_X, Y: real_Y})
        #print("Y_pred:", Y_pred)
    
    # Dump MPC runtime perf counters, then tear down the protocol.
    print(rtt.get_perf_stats(pretty = True, task_id = task_id))
    rtt.deactivate(task_id = task_id)
# Driver: run the secure-regression test once (single round, single task)
# through the multi-task unit-test framework, printing a numbered
# separator after each round.
for round_no in range(1):
    tasks = [test for _ in range(1)]
    mt_unit_test_fw(tasks)
    print((str(round_no) + '-') * 80)
    time.sleep(3)
